{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfakes with GANs\n",
    "> Re-enactment using Pix2Pix\n",
    "\n",
    "<img src=\"deepfake_banner.png\">\n",
    "\n",
    "We covered image-to-image translation GAN architectures in Chapter 5. Particularly, we discussed in detail how **pix2pix GAN** is a powerful architecture which enables paired translation tasks. In this notebook, we will leverage pix2pix GAN to develop a face re-enactment setup from scratch. We will:\n",
    "+ build a pix2pix network\n",
    "+ prepare the dataset using a video\n",
    "+ train the model for reenactment using facial landmarks\n",
    "\n",
    "The actual reenactment part is covered in the second notebook for this chapter. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N6FXnJHk32ND"
   },
   "source": [
    "## Load Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9KlERwzUJbNr"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import dlib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.utils import save_image\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gan_utils import PATCH_GAN_SHAPE\n",
    "from gan_utils import Generator,Discriminator \n",
    "from gan_utils import (IMG_WIDTH,\n",
    "                        IMG_HEIGHT,\n",
    "                        NUM_CHANNELS,\n",
    "                        BATCH_SIZE,\n",
    "                        N_EPOCHS,\n",
    "                        SAMPLE_INTERVAL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils import ImageDataset, prepare_data\n",
    "from dataset_utils import DATASET_PATH, DOWNSAMPLE_RATIO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_T7bKz_T5EP4"
   },
   "source": [
    "## Set Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-3M00EI0suiY"
   },
   "outputs": [],
   "source": [
    "CUDA = True if torch.cuda.is_available() else False\n",
    "os.makedirs(\"saved_models/\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "_x6XGNalsuVz",
    "outputId": "f079c7c3-461b-4a9f-fa83-8ba81ca41251"
   },
   "outputs": [],
   "source": [
    "# get landmarks model if not already available\n",
    "!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\n",
    "!bunzip2 \"shape_predictor_68_face_landmarks.dat.bz2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PDFOoTQks6JM"
   },
   "outputs": [],
   "source": [
    "# instantiate objects for face and landmark detection\n",
    "detector = dlib.get_frontal_face_detector()\n",
    "predictor = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mEWl8xxU3sKu"
   },
   "source": [
    "# Pix2Pix GAN for Re-enactment\n",
    "\n",
    "In their work titled [“Image to Image Translation with Conditional Adversarial Networks”](https://arxiv.org/abs/1611.07004), Isola and Zhu et. al. present a conditional GAN network which is able to learn task specific loss functions and thus work across datasets. As the name suggests, this GAN architecture takes a specific type of image as input and transforms it into a different domain. It is called pair-wise style transfer as the training set needs to have samples from both, source and target domains."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "E5hkGY1V4OG8"
   },
   "source": [
    "## U-Net Generator\n",
    "The U-Net architecture uses skip connections to shuttle important features between the input and outputs. In case of pix2pix GAN, skip connections are added between every $ith$ down-sampling and $(n-i)th$ over-sampling layers, where $n$ is the total number of layers in the generator. The skip connection leads to concatenation of all channels from the ith and $(n-i)th$ layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d5gQ7yrG4Uzp"
   },
   "source": [
    "## Patch-GAN Discriminator\n",
    "The authors for pix2pix propose a Patch-GAN setup for the discriminator which takes the required inputs and generates an output of size NxN. Each $x_{ij}$ element of the NxN output signifies whether the corresponding patch ij in the generated image is real or fake. Each output patch can be traced back to its initial input patch basis the effective receptive field for each of the layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FbXMgH8_5Pyp"
   },
   "source": [
    "## Initialize Generator and Discriminator Model Objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CFJbgyHeRlFh"
   },
   "outputs": [],
   "source": [
    "# Initialize generator and discriminator\n",
    "generator = Generator()\n",
    "discriminator = Discriminator()\n",
    "\n",
    "# Loss functions\n",
    "adversarial_loss = torch.nn.MSELoss()\n",
    "pixelwise_loss = torch.nn.L1Loss()\n",
    "\n",
    "# Loss weight of L1 pixel-wise loss between translated image and real image\n",
    "weight_pixel_wise_identity = 100\n",
    "\n",
    "# Optimizers\n",
    "optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2bTJhEZG5brD"
   },
   "source": [
    "## Prepare Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "plTpZOWDs6GD",
    "outputId": "a7017603-2471-4dbf-dcab-f0b400b95cc3"
   },
   "outputs": [],
   "source": [
    "# prepare data\n",
    "prepare_data('obama.mp4',\n",
    "             detector,\n",
    "             predictor,\n",
    "             num_samples=400,\n",
    "             downsample_ratio = DOWNSAMPLE_RATIO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Objects based on GPU Availability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6dr3Jr7ySNJ4"
   },
   "outputs": [],
   "source": [
    "if CUDA:\n",
    "    generator = generator.cuda()\n",
    "    discriminator = discriminator.cuda()\n",
    "    adversarial_loss.cuda()\n",
    "    pixelwise_loss.cuda()\n",
    "    Tensor = torch.cuda.FloatTensor\n",
    "else:\n",
    "  Tensor = torch.FloatTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VvRfehAN5o8I"
   },
   "source": [
    "## Define Transformations and Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DjbEdO81SNHa"
   },
   "outputs": [],
   "source": [
    "image_transformations = [\n",
    "    transforms.Resize((IMG_HEIGHT, IMG_WIDTH), Image.BICUBIC),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lv0x9dX_SNEs"
   },
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(\n",
    "    ImageDataset(DATASET_PATH, image_transformations=image_transformations),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3cA2eGNXSNBz"
   },
   "outputs": [],
   "source": [
    "val_dataloader = DataLoader(\n",
    "    ImageDataset(DATASET_PATH,image_transformations=image_transformations),\n",
    "    batch_size=BATCH_SIZE//8,\n",
    "    shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sfCtlMsj5tTV"
   },
   "source": [
    "## Training Begins!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_images(val_dataloader,batches_done):\n",
    "    \"\"\"\n",
    "        Method to generate sample images for validation\n",
    "        Parameters:\n",
    "            val_dataloader: instance of dataloader\n",
    "            batches_done: training iteration counter\n",
    "    \"\"\"\n",
    "    imgs = next(iter(val_dataloader))\n",
    "    # condition\n",
    "    real_A = Variable(imgs[\"B\"].type(Tensor))\n",
    "    # real\n",
    "    real_B = Variable(imgs[\"A\"].type(Tensor))\n",
    "    # generated\n",
    "    generator.eval()\n",
    "    fake_B = generator(real_A)\n",
    "    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)\n",
    "    save_image(img_sample, f\"{DATASET_PATH}/{batches_done}.png\", nrow=4, normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "41d5MJCGZoZv",
    "outputId": "05b6e281-ada8-4244-d747-cc8fd319ecde"
   },
   "outputs": [],
   "source": [
    "for epoch in range(0, N_EPOCHS):\n",
    "    for i, batch in enumerate(train_dataloader):\n",
    "\n",
    "        # prepare inputs\n",
    "        real_A = Variable(batch[\"B\"].type(Tensor))\n",
    "        real_B = Variable(batch[\"A\"].type(Tensor))\n",
    "\n",
    "        # ground truth\n",
    "        valid = Variable(Tensor(np.ones((real_A.size(0), *PATCH_GAN_SHAPE))), requires_grad=False)\n",
    "        fake = Variable(Tensor(np.zeros((real_A.size(0), *PATCH_GAN_SHAPE))), requires_grad=False)\n",
    "\n",
    "        #  Train Generator\n",
    "        optimizer_G.zero_grad()\n",
    "\n",
    "        # generator loss\n",
    "        fake_B = generator(real_A)\n",
    "        pred_fake = discriminator(fake_B, real_A)\n",
    "        adv_loss = adversarial_loss(pred_fake, valid)\n",
    "        loss_pixel = pixelwise_loss(fake_B, real_B)\n",
    "\n",
    "        # Overall Generator loss\n",
    "        g_loss = adv_loss + weight_pixel_wise_identity * loss_pixel\n",
    "\n",
    "        g_loss.backward()\n",
    "\n",
    "        optimizer_G.step()\n",
    "\n",
    "        #  Train Discriminator\n",
    "        optimizer_D.zero_grad()\n",
    "\n",
    "        pred_real = discriminator(real_B, real_A)\n",
    "        loss_real = adversarial_loss(pred_real, valid)\n",
    "        pred_fake = discriminator(fake_B.detach(), real_A)\n",
    "        loss_fake = adversarial_loss(pred_fake, fake)\n",
    "\n",
    "        # Overall Discriminator loss\n",
    "        d_loss = 0.5 * (loss_real + loss_fake)\n",
    "\n",
    "        d_loss.backward()\n",
    "        optimizer_D.step()\n",
    "\n",
    "        # Progress Report\n",
    "        batches_done = epoch * len(train_dataloader) + i\n",
    "        print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(train_dataloader)}--D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}--Adv.Loss:{adv_loss.item():.4f}')\n",
    "\n",
    "        # generate samples\n",
    "        if batches_done % SAMPLE_INTERVAL == 0:\n",
    "            sample_images(val_dataloader,batches_done)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the Trained Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sIzQFLMrajAw"
   },
   "outputs": [],
   "source": [
    "torch.save(generator.state_dict(), \"saved_models/generator.pt\")\n",
    "torch.save(discriminator.state_dict(), \"saved_models/discriminator.pt\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
